import torch
from torch import optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

import torch.nn as nn
from torch.autograd import Variable
from mmd import mix_rbf_mmd2, rbf_mmd2, batched_rbf_mmd2


def get_MMD_values_uneven(D_Xs, D_Ys, V_X, V_Y, netD = None, device=torch.device('cuda'), sigma_list = None, squared=False, batch_size=1024):
    """
    Value each dataset in ``D_Xs`` by its negated MMD to the reference ``V_X``.

    Based on an implementation of MMD2 that enables different-sized inputs.
    Also leverages a list of RBF bandwidths (``sigma_list``).

    Args:
        D_Xs: iterable of tensors (one per data vendor) to be valued.
        D_Ys: unused; kept for interface compatibility.
        V_X: reference tensor every entry of ``D_Xs`` is compared against.
        V_Y: unused; kept for interface compatibility.
        netD: unused; kept for interface compatibility.
        device: device (``torch.device`` or string) to move the data to. If it
            is a CUDA device and more than one GPU is visible, the MMD itself is
            computed on ``cuda:1``.
        sigma_list: RBF bandwidths forwarded to ``batched_rbf_mmd2``; defaults
            to ``[1, 2, 5, 10]``.
        squared: if True each value is ``-MMD2``; otherwise it is
            ``-sqrt(max(MMD2, 1e-6))``.
        batch_size: batch size forwarded to ``batched_rbf_mmd2``.

    Returns:
        list of Python floats, one per entry of ``D_Xs``, in the same order.
        Values closer to 0 mean a smaller discrepancy to the reference.
    """
    if sigma_list is None:
        sigma_list = [1, 2, 5, 10]
    # Normalise so `.type` is available even when a string such as 'cpu' is passed.
    device = torch.device(device)

    results = []
    V_X = V_X.to(device)

    # Shuffle the reference once: each D_X is compared against a prefix of it
    # (see the slicing below), and batching happens downstream, so the prefix
    # should be a random subset rather than the first rows of V_X.
    rand_indx = torch.randperm(len(V_X))
    permuted_V_X = V_X[rand_indx]
    for D_X in D_Xs:
        D_X = D_X.to(device)

        # Truncate both sides to a common length.
        min_len = min(V_X.size(0), D_X.size(0))
        ref = permuted_V_X[:min_len]
        D_X = D_X[:min_len]

        # Offload the MMD to a second GPU when available, but only if the caller
        # actually asked for CUDA (never override an explicit CPU request).
        if device.type == 'cuda' and torch.cuda.device_count() > 1:
            device = torch.device('cuda:1')

        # Batched version of rbf_mmd2 to avoid OOM errors on large inputs.
        MMD2 = batched_rbf_mmd2(D_X, ref, sigma_list, device=device, batch_size=batch_size)
        if squared:
            results.append(-MMD2.item())
        else:
            # The MMD2 estimate can be <= 0; clamp before the square root.
            # (Python's max(1e-6, MMD2) returned a float in that case, which
            # torch.sqrt rejects.)
            results.append(-torch.sqrt(torch.clamp(MMD2, min=1e-6)).item())
    return results


from sklearn.utils import resample
def get_mix_reference(D_Xs, generated_reference, pct, device=torch.device('cuda')):
    '''
    Build a reference set mixing real (union of ``D_Xs``) and generated samples.

    composition of the returned mix_reference: (1-pct) * D_N + pct * generated,
    where D_N is the concatenation of ``D_Xs``. So, pct is lambda in paper.

    The total size is m = min(len(D_N), len(generated_reference)). It is split
    with a single rounding so the two parts always sum to exactly m. Each part
    is drawn with sklearn's ``resample`` (with replacement by default).

    Args:
        D_Xs: sequence of tensors, concatenated along dim 0 to form D_N.
        generated_reference: tensor of generated samples.
        pct: fraction of the mix taken from ``generated_reference``, in [0, 1].
        device: device both sources are moved to before sampling.

    Returns:
        Tensor with m rows: first the D_N part, then the generated part.

    Raises:
        ValueError: if ``pct`` is not in [0, 1].
    '''
    if not 0 <= pct <= 1:
        raise ValueError(f"pct must be in [0, 1], got {pct}")

    D_N = torch.cat(D_Xs).to(device)
    generated_reference = generated_reference.to(device)
    m = min(len(D_N), len(generated_reference))

    # Round once and derive the other count, so truncating both products
    # separately cannot drop a sample.
    n_generated = int(round(pct * m))
    n_real = m - n_generated

    # Recent sklearn versions reject n_samples=0 in resample. Use an empty
    # slice instead, which keeps dtype/device, so pct = 0 or 1 work.
    D_N_sub = resample(D_N, n_samples=n_real) if n_real > 0 else D_N[:0]
    generated_reference_sub = (
        resample(generated_reference, n_samples=n_generated)
        if n_generated > 0 else generated_reference[:0]
    )

    return torch.cat([D_N_sub, generated_reference_sub])

from collections import defaultdict
from copy import deepcopy

from utils import set_deterministic, save_results
from data_utils import assign_data, _get_loader

import torch

from tqdm import tqdm
import argparse
from os.path import join as oj

baseline = 'Ours'  # label used when saving results; may be suffixed later by the script


class options:
    """Experiment settings kept as class attributes and read without instantiating.

    The script entry point overwrites ``mmd_batch_size`` for some datasets.
    """

    # Generic sizing knobs.
    epochs = 100
    batch_size = 256
    image_size = 32
    n_filters = 100

    # Batch size apparently meant for the batched MMD computation.
    mmd_batch_size = 1024

    # Whether CUDA was available when this module was imported.
    cuda = torch.cuda.is_available()


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Process which dataset to run')
    parser.add_argument('-N', '--N', help='Number of data vendors.', type=int, required=True, default=5)
    parser.add_argument('-m', '--size', help='Size of sample datasets.', type=int, required=True, default=1500)
    parser.add_argument('-P', '--dataset', help='Pick the dataset to run.', type=str, required=True)
    parser.add_argument('-Q', '--Q_dataset', help='Pick the Q dataset.', type=str, required=False, choices=['normal', 'EMNIST', 'FaMNIST', 'CIFAR100' , 'CreditCard', 'UGR16'])
    parser.add_argument('-n_t', '--n_trials', help='Number of trials.', type=int, default=5)
    parser.add_argument('-nh', '--not_huber', help='Not with huber, meaning with other types of specified heterogeneity.', action='store_true')
    parser.add_argument('-het', '--heterogeneity', help='Type of heterogeneity.', type=str, default='normal', choices=['normal', 'label', 'classimbalance'])
    # -kde / -gmm share dest 'gmm'. When neither flag is given, argparse uses
    # the default of the first-registered action (-kde, store_false -> True).
    parser.add_argument('-kde', dest='gmm', help='Whether to use KDE for generator distribution. Only applicable to CreditCard or TON dataset.', action='store_false')
    parser.add_argument('-gmm', dest='gmm', help='Whether to use GMM for generator distribution. Only applicable to CreditCard or TON dataset.', action='store_true')

    # -UR / -MR share dest 'under_report'. By the same rule the default is
    # False, i.e. mis-reporting.
    parser.add_argument('-UR', dest='under_report', help='Whether to check IC for under-reporting.', action='store_true')
    parser.add_argument('-MR', dest='under_report', help='Whether to check IC for mis-reporting.', action='store_false')

    cmd_args = parser.parse_args()
    print(cmd_args)

    dataset = cmd_args.dataset
    Q_dataset = cmd_args.Q_dataset
    N = cmd_args.N
    size = cmd_args.size
    n_trials = cmd_args.n_trials
    not_huber = cmd_args.not_huber
    heterogeneity = cmd_args.heterogeneity
    use_GMM = cmd_args.gmm
    under_report = cmd_args.under_report

    # Dataset-specific override.
    # NOTE(review): options.mmd_batch_size is not read anywhere below; every
    # MMD call passes options.batch_size. Confirm which one is intended.
    if dataset in ('MNIST', 'CIFAR10'):
        options.mmd_batch_size = 256

    print(f"----- Running experiment for {baseline} IC test -----")

    set_deterministic()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Strength of the dishonest report (constant across trials and vendors).
    under_report_pct = 0.8        # fraction of D_X kept when under-reporting
    mis_report_noise_sigma = 0.2  # std of the Gaussian noise added when mis-reporting

    # Honest valuations: w.r.t. V_X and w.r.t. the union of all D_Xs ("hat").
    values_over_trials, values_hat_over_trials = [], []

    # the baseline of MMD2:
    # when a reference is given, use the reference, but the valuation is based on MMD squared
    values_mmd2_over_trials, values_hat_mmd2_over_trials = [], []

    # Dishonest report values (mis-report or under-report). Each defaultdict maps
    # the index of the dishonest vendor to the per-trial list of all vendors' values,
    # e.g. dishonest_rp_values_over_trials[0] holds the trials where vendor 0 was dishonest.
    dishonest_rp_values_over_trials, dishonest_rp_values_hat_over_trials = defaultdict(list), defaultdict(list)
    dishonest_rp_values_hat_loo_over_trials = defaultdict(list)
    dishonest_rp_mmd2_over_trials, dishonest_rp_mmd2_hat_over_trials = defaultdict(list), defaultdict(list)

    for _ in tqdm(range(n_trials), desc=f'A total of {n_trials} trials.'):
        # raw data
        D_Xs, D_Ys, V_X, V_Y, labels = assign_data(N, size, dataset, Q_dataset, not_huber, heterogeneity)

        reference = torch.cat(D_Xs)
        netD = None
        print(f"The shape of reference (via union): {reference.shape}.")
        try:
            MMD_values = get_MMD_values_uneven(D_Xs, None, V_X, None, netD, device, batch_size=options.batch_size)
            values_over_trials.append(MMD_values)

            MMD2_values = get_MMD_values_uneven(D_Xs, None, V_X, None, netD, device, squared=True, batch_size=options.batch_size)
            values_mmd2_over_trials.append(MMD2_values)

            MMD2_values_hat = get_MMD_values_uneven(D_Xs, None, reference, None, netD, device, squared=True, batch_size=options.batch_size)
            values_hat_mmd2_over_trials.append(MMD2_values_hat)

            MMD_values_hat = get_MMD_values_uneven(D_Xs, None, reference, None, netD, device, batch_size=options.batch_size)
            values_hat_over_trials.append(MMD_values_hat)

            for i, D_X in enumerate(D_Xs):
                if under_report:
                    # Report only a bootstrap subsample of the true data.
                    D_X_dishonest = resample(D_X, n_samples=int(under_report_pct * len(D_X)))
                else:  # mis-reporting: perturb the true data with Gaussian noise
                    D_X_dishonest = D_X + torch.randn_like(D_X) * mis_report_noise_sigma

                # A shallow copy suffices because only slot i is replaced; the
                # other vendors' tensors are shared instead of duplicated.
                D_Xs_ = list(D_Xs)
                D_Xs_[i] = D_X_dishonest

                reference_hat = torch.cat(D_Xs_)
                # Leave-one-out reference: union of everyone except vendor i.
                reference_hat_LOO = torch.cat([D_X_ for j, D_X_ in enumerate(D_Xs_) if j != i])

                MMD_values = get_MMD_values_uneven(D_Xs_, None, V_X, None, netD, device, batch_size=options.batch_size)
                dishonest_rp_values_over_trials[i].append(MMD_values)

                MMD2_values = get_MMD_values_uneven(D_Xs_, None, V_X, None, netD, device, squared=True, batch_size=options.batch_size)
                dishonest_rp_mmd2_over_trials[i].append(MMD2_values)

                MMD2_values_hat = get_MMD_values_uneven(D_Xs_, None, reference_hat, None, netD, device, squared=True, batch_size=options.batch_size)
                dishonest_rp_mmd2_hat_over_trials[i].append(MMD2_values_hat)

                MMD_values_hat = get_MMD_values_uneven(D_Xs_, None, reference_hat, None, netD, device, batch_size=options.batch_size)
                dishonest_rp_values_hat_over_trials[i].append(MMD_values_hat)

                MMD_values_hat_loo = get_MMD_values_uneven(D_Xs_, None, reference_hat_LOO, None, netD, device, batch_size=options.batch_size)
                dishonest_rp_values_hat_loo_over_trials[i].append(MMD_values_hat_loo)

            print("Valuation with mis/under-reporting is complete.")

        except RuntimeError as e:
            # There is no CPU fallback here. Report OOM clearly and re-raise the
            # original error so its message and traceback are preserved.
            if 'CUDA out of memory' in str(e):
                print('CUDA out of memory during MMD computation; consider a smaller batch size.')
            raise

    if dataset in ('CreditCard', 'TON'):
        baseline = baseline + ("_GMM" if use_GMM else '_KDE')

    results = {'values_over_trials': values_over_trials, 'values_hat_over_trials': values_hat_over_trials,
    'N':N, 'size':size, 'n_trials': n_trials, 'isHuber':not not_huber, 'heterogeneity': heterogeneity, 'use_GMM': use_GMM}

    report_kind = 'under-report' if under_report else 'mis-report'
    exp_name = f'{dataset}_vs_{Q_dataset}-N{N} m{size} n_trials{n_trials} IC-test {report_kind}'
    exp_name = oj('IC', exp_name)

    for i in range(N):
        results[f"dishonest_rp_values_over_trials-{i}"] = dishonest_rp_values_over_trials[i]
        results[f"dishonest_rp_values_hat_over_trials-{i}"] = dishonest_rp_values_hat_over_trials[i]
        results[f"dishonest_rp_values_hat_loo_over_trials-{i}"] = dishonest_rp_values_hat_loo_over_trials[i]

    save_results(baseline=baseline, exp_name=exp_name, **results)

    # For MMD squared w.r.t. half mix reference
    results = {'values_over_trials': values_mmd2_over_trials, 'values_hat_mmd2_over_trials': values_hat_mmd2_over_trials,
               'N':N, 'size':size, 'n_trials': n_trials, 'isHuber': not not_huber, 'heterogeneity': heterogeneity, 'use_GMM': use_GMM}
    for i in range(N):
        results[f"dishonest_rp_mmd2_over_trials-{i}"] = dishonest_rp_mmd2_over_trials[i]
        results[f"dishonest_rp_mmd2_hat_over_trials-{i}"] = dishonest_rp_mmd2_hat_over_trials[i]

    save_results(baseline='MMD_sq', exp_name=exp_name, **results)

